import data_collection
import train
import numpy as np
from termcolor import colored
import torch
import datetime
import decoder
import copy
import policy


if __name__ == "__main__":
    # Driver for the f-VI-IGL experiment: reads the configuration
    # interactively, builds the train/test datasets once, then trains a fresh
    # reward decoder and policy for `num_runs` independent runs, printing the
    # running average of the curves returned by train.train_vib_igl.

    # Experiment set-up (interactive)
    noise = float(input('Noise Level (Enter a float between 0-1): '))
    # The noise type is only asked for when noise is actually injected.
    if noise > 0:
        noise_type = input('Noise Type (Action_Inclusive/Context_Inclusive/Action_Context_Inclusive/Independent): ')
    else:
        noise_type = 'No'
    decoder_type = input('Decoder Type (full: (X,A,Y)-dependent/partial: Y-dependent): ')
    div_measure = input('Divergence measures (KL/PEARSON/PEARSON_KL): ')
    gamma = float(input('gamma (= 1 / beta in Objective (4)); to set beta = 0, input any gamma > 10'))

    # Reproducibility: seed both RNGs. NumPy's seed alone does not cover
    # PyTorch's generator, which the torch modules created below presumably
    # draw their initial weights from.
    seed = 321
    np.random.seed(seed)
    torch.manual_seed(seed)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Problem / data sizes
    num_action = 10
    num_train_data = 60000
    num_test_data = 10000

    # Number of independent runs the printed results are averaged over.
    num_runs = 16

    # Shared training parameters
    sub_sample = 5
    batch = 600
    num_update = 500
    eval_every = 500
    T_lr = 16e-5
    d_lr = 16e-5
    p_batch = 600
    # Training Parameters 1 (Reward Decoder)
    epoch_vib = num_update * 2
    T_w_decay = 0.3
    d_w_decay = 0.8
    # Training Parameters 2 (Policy)
    p_epoch = 8000
    p_lr = 1e-4

    # Log the full configuration before any (slow) work starts.
    print(colored('Start_Time: {}, Device: {}, {}-{}_Noise, '
                  'Training_data_size: {}, Test_data_size: {}'
                  .format(datetime.datetime.now(), device, noise, noise_type,
                          num_train_data, num_test_data), 'blue'))
    print("")
    print(colored('[f-VI-IGL Config] VIB-Decoder-Type: {}, gamma: {}, Epochs: {}, Batch_size: {}, Sub_Sample: {}, '
                  'MINE_lr: {}, T_wd: {}, Dec_lr: {}, Dec_wd: {}, Policy_lr: {}, '
                  'p_batch: {}, DIV_measure: {}'
                  .format(decoder_type, gamma, epoch_vib, batch, sub_sample,
                          T_lr, T_w_decay, d_lr, d_w_decay, p_lr,
                          p_batch, div_measure), 'blue'))

    # Data collection: done once, the same datasets are reused by every run.
    D = data_collection.data_collection(num_data=num_train_data,
                                        device=device,
                                        noise=noise,
                                        noise_type=noise_type)
    D_test = data_collection.test_data_collection(num_data=num_test_data,
                                                  device=device,
                                                  noise=noise,
                                                  noise_type=noise_type)

    # Accumulators for the running averages over the runs.
    # NOTE(review): the lengths must match the arrays returned by
    # train.train_vib_igl (the `+=` below relies on it) -- confirm there.
    vib_value_curve = np.zeros(p_epoch // (2 * eval_every) + 1)
    vib_value = []  # final policy value of each run, used for the std
    vib_decoder_curve = np.zeros(num_update // eval_every + 1)

    for run_idx in range(num_runs):
        runs_done = run_idx + 1
        print("")

        # Initialization: fresh models for every run.
        r_decoder = decoder.RewardDecoder(type=decoder_type, num_action=num_action).to(device)
        IGL_policy = policy.Policy(num_action=num_action).to(device)

        # Training f-VI-IGL
        value_curve, decoder_curve = train.train_vib_igl(device=device,
                                                         iteration=runs_done,
                                                         test_each=eval_every,
                                                         dataset=D,
                                                         test_data=D_test,

                                                         num_epochs=epoch_vib,
                                                         num_batch=batch,
                                                         sub_sample_num=sub_sample,
                                                         gamma=gamma,
                                                         T_lr=T_lr,
                                                         T_wd=T_w_decay,
                                                         f_measure=div_measure,

                                                         r_decoder=r_decoder,
                                                         d_lr=d_lr,
                                                         d_wd=d_w_decay,

                                                         igl_policy=IGL_policy,
                                                         p_epoch=p_epoch,
                                                         p_lr=p_lr,
                                                         num_action=num_action,
                                                         p_batch=p_batch)
        vib_value_curve += value_curve
        vib_decoder_curve += decoder_curve
        vib_value.append(value_curve[-1])

        # Report the averages over the runs completed so far; the std is taken
        # over the last entry of each run's first returned curve.
        print(colored('VIB-f-IGL: {}-{}-{}-{}-{}, Policy_value: {}, std: {}, Decoder_MSE: {}'
                      .format(decoder_type, gamma, noise, noise_type, div_measure,
                              vib_value_curve / runs_done,
                              np.std(np.array(vib_value)),
                              vib_decoder_curve / runs_done),
                      'green'))
        print("")
